import scipy as sc
import matplotlib as mpl
import pandas as pd
from pathlib import Path
from scipy.special import expit
import bluepysnap
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import find_peaks


def get_spindle_amp(
    sim,
    *,
    population="thalamus_neurons",
    mtype="Rt_RC",
    time_binsize=5,
    window_start=3500,
    window_length=1000,
    n_windows=16,
    peak_height=2,
    peak_width=3,
):
    """Measure spindle peak amplitude and extent from a population firing rate.

    Spikes of cells with the given ``mtype`` in ``population`` are binned into a
    population rate histogram. The rate is spikes per bin divided by the number of
    cells in the report and by the bin width times 0.001. Times are presumably in
    ms, so the rate is in Hz; confirm against the report units.

    The histogram is then cut into ``n_windows`` consecutive windows of
    ``window_length`` starting at ``window_start``. In each window the rate peaks
    are found with ``scipy.signal.find_peaks(height=peak_height, width=peak_width)``.

    Args:
        sim: simulation object exposing
            ``sim.spikes[population].spike_report.filter(group=...).report``,
            a DataFrame indexed by spike time with ``ids``/``population`` columns.
        population: spike population to read.
        mtype: mtype used to filter the cells.
        time_binsize: histogram bin width (same unit as the spike times).
        window_start: start of the first analysis window.
        window_length: length of each analysis window.
        n_windows: number of consecutive analysis windows.
        peak_height: minimum peak rate passed to ``find_peaks``.
        peak_width: minimum peak width, in bins, passed to ``find_peaks``.

    Returns:
        ``(mean_amp, std_amp, mean_len, std_len)``. These are the mean and std,
        across windows, of the average peak rate and of the time between the first
        and last peak. A window without peaks contributes 0 to both; a window with
        a single peak contributes length 0.

    Raises:
        ValueError: if the filtered report contains no spikes.
    """
    report = sim.spikes[population].spike_report.filter(group={"mtype": mtype}).report

    times = report.index
    if len(times) == 0:
        raise ValueError(
            f"no spikes found for mtype {mtype!r} in population {population!r}"
        )
    time_start = np.min(times)
    time_stop = np.max(times)
    # Append time_stop explicitly so the final (possibly shorter) bin ends exactly there.
    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(times, bins=bins)
    # NOTE: only cells that spiked at least once appear in the report, so this is
    # the number of *active* cells, not the full cell count of the mtype.
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    # Counts per bin -> per-cell rate; 0.001 converts the bin width (presumably ms) to s.
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)
    bin_starts = bin_edges[:-1]

    amps = []
    lengths = []
    for step in range(n_windows):
        low = window_start + step * window_length
        high = low + window_length
        # Both ends inclusive: a bin starting exactly on a boundary is seen by both
        # neighbouring windows (kept for backward compatibility).
        mask = (bin_starts >= low) & (bin_starts <= high)
        f = freq[mask]
        t = bin_starts[mask]
        peaks = find_peaks(f, height=peak_height, width=peak_width)[0]
        if peaks.size == 0:
            amps.append(0)
            lengths.append(0)
        else:
            amps.append(np.mean(f[peaks]))
            lengths.append(t[peaks[-1]] - t[peaks[0]])
    return np.mean(amps), np.std(amps), np.mean(lengths), np.std(lengths)


def fit_func(_x, shift_exp, slope_exp=4.0, amp_exp=60):
    """Evaluate a logistic curve sitting on a constant baseline of 2.

    Returns ``2 + amp_exp * expit((_x - shift_exp) / slope_exp)``. For positive
    ``slope_exp`` and ``amp_exp`` the curve rises from 2 towards ``2 + amp_exp``.
    Its midpoint, ``2 + amp_exp / 2``, is reached at ``_x == shift_exp``. Inputs
    broadcast element-wise, so NumPy arrays are accepted for any argument.
    """
    baseline = 2
    z = (_x - shift_exp) / slope_exp
    return baseline + amp_exp * expit(z)


if __name__ == "__main__":
    # Import the submodule explicitly rather than relying on ``sc.stats`` being
    # reachable as an attribute of the top-level ``scipy`` package.
    from scipy.stats import hmean

    base_path = Path("../../simulations/08_08_Cav_variants/")

    # Linear map from effective it2 conductance to the ``shift_exp`` argument of
    # ``fit_func``. NOTE(review): origin of these coefficients is not shown here.
    p = np.poly1d([-48415.99164654, 115.95393407])
    its = np.linspace(0, 0.0012, 100)
    center = p(its)

    # Effective it2 conductance of each variant: the weighted harmonic mean of
    # three conductance values. "reducedEcel" lowers the first value by ``shift``;
    # "reducedSppRunaway" lowers the second and third.
    # NOTE(review): the first value gets weight f=0.1 although the model name reads
    # "mixed25pEcel75pSppRunaway" - confirm the intended fraction.
    f = 0.1
    shift = 0.0002
    weights = [f, 0.5 * (1 - f), 0.5 * (1 - f)]
    variants = [
        {
            "var": 11,
            "model": "mixed25pEcel75pSppRunaway",
            "it2": hmean([0.0005, 0.0007, 0.001], weights=weights),
            "color": "r",
            "label": "base",
        },
        {
            "var": 12,
            "model": "mixed25pEcel75pSppRunaway-reducedEcel",
            "it2": hmean([0.0005 - shift, 0.0007, 0.001], weights=weights),
            "color": "m",
            "label": "ecel reduction",
        },
        {
            "var": 13,
            "model": "mixed25pEcel75pSppRunaway-reducedSppRunaway",
            "it2": hmean([0.0005, 0.0007 - shift, 0.001 - shift], weights=weights),
            "color": "b",
            "label": "spp reduction",
        },
    ]

    # Reference curves (one column per CT percentage). They do not depend on the
    # loop variable, so load them and set up their colouring only once.
    data = pd.read_csv("../figure_6/spindle_data.csv", index_col=0)
    _its = data.index.to_numpy()
    # CSV header labels are strings; convert once for arithmetic and colour scaling.
    column_percents = data.columns.astype(float)
    cmap = plt.get_cmap("plasma")
    colors = cmap(np.linspace(0, 1, len(data.columns)))
    norm = mpl.colors.Normalize(vmin=column_percents[0], vmax=column_percents[-1])
    # it2 offset applied per CT-percent difference when overlaying reference curves.
    it2_per_ct_percent = 1.90298708e-05

    for perc in [80, 85, 90, 100]:
        f_means = []
        for v in variants:
            var, model = v["var"], v["model"]
            sim_path = (
                base_path
                / f"variant{var}_{model}/variant{var}_{model}_SUPER_LONG_{perc}percent_CT_up_down/simulation_config.json"
            )
            f_mean = get_spindle_amp(bluepysnap.Simulation(sim_path))[0]
            print(f_mean)
            f_means.append(f_mean)

        amp = fit_func(perc, center)

        fig = plt.figure(figsize=(5, 3))
        for local_per, local_per_value, c in zip(data.columns, column_percents, colors):
            shifted_its = -(perc - local_per_value) * it2_per_ct_percent + _its
            plt.plot(shifted_its, data[local_per], "-", c=c, lw=0.8)

        plt.plot(its, amp, "-", c="k", lw=1.2, label="model")
        for v, f_mean in zip(variants, f_means):
            plt.scatter(
                v["it2"], f_mean, marker="x", c=v["color"], label=v["label"], zorder=10
            )

        plt.colorbar(
            mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
            ax=plt.gca(),
            orientation="vertical",
            shrink=0.6,
            label="CT percentage",
        )

        plt.legend()
        plt.xlabel("effective it2 conductance")
        plt.ylabel("spindle amplitude [Hz]")
        plt.axis([0.0, 0.0012, 0, 70])
        plt.tight_layout()
        plt.savefig(f"cell_population_fraction_{perc}.pdf")
        # Release the figure; otherwise every loop iteration leaves one open.
        plt.close(fig)
